#!/usr/bin/python

import json
import random
import pickle
import time
import math
import copy
import os.path
import pandas as pd
import numpy as np
import argparse

from data.data import Data
from data.util import circular_mean
from node.node import Node

def is_in_bounds(bounds, value):
    """Return True when *value* falls inside at least one interval of *bounds*.

    Each interval is indexable, with the lower limit at position 0 and the
    upper limit at position 1. Three cases are handled:

    * lower > upper: the interval wraps around (presumably degrees on a
      circular scale, given the hard-coded 360.0 -- confirm with callers);
      the value matches if it lies in [lower, 360.0] or in [0.0, upper).
    * lower == 0.0 and value == 360: accepted outright.
    * otherwise: ordinary half-open interval [lower, upper).
    """
    def _contains(interval):
        lower, upper = interval[0], interval[1]

        if lower > upper:
            # Wrapped interval: it spans the 360 -> 0 seam.
            return lower <= value <= 360.0 or 0.0 <= value < upper

        if lower == 0.0 and value == 360:
            return True

        return lower <= value < upper

    return any(_contains(interval) for interval in bounds)

def tree_eval(node, row):
    """Predict the class variable for *row* by walking the tree rooted at *node*.

    At every internal node (``split_var`` is not None) the row's value for the
    split variable is tested against the bounds stored on the left child and
    then on the right child, and the walk descends into the first child whose
    bounds contain it. At a leaf the prediction is the mean of the class
    variable over the leaf's data: ``np.mean`` for type ``'lin'`` and
    ``circular_mean`` for type ``'cir'``.

    Args:
        node: Root of the (sub)tree. Must expose ``split_var``, ``left_child``,
            ``right_child`` and ``data`` (with ``class_var``, ``var_desc``
            and ``df``).
        row: Mapping-like record (e.g. a pandas Series) indexed by variable
            name.

    Returns:
        The mean of the class variable in the leaf reached by *row*.

    Raises:
        Exception: If the class variable type is neither ``'lin'`` nor ``'cir'``.
        SystemExit: If the row's value is in neither child's bounds.
    """
    # Imported locally because ``sys`` is not among the file-level imports.
    import sys

    current = node

    # Descend iteratively until a leaf (a node without a split variable).
    while current.split_var is not None:
        split_var = current.split_var
        value = row[split_var]

        left_bounds = current.left_child.data.var_desc[split_var]["bounds"]
        if is_in_bounds(left_bounds, value):
            current = current.left_child
            continue

        right_bounds = current.right_child.data.var_desc[split_var]["bounds"]
        if is_in_bounds(right_bounds, value):
            current = current.right_child
            continue

        # The value fits neither child: dump parent/left/right bounds and abort.
        print(current.data.var_desc[split_var]["bounds"], value)
        print(left_bounds, value)
        print(right_bounds, value)
        sys.exit("tree_eval() problem with bounds ????")

    class_var = current.data.class_var

    if class_var['type'] == 'lin':
        return np.mean(current.data.df[class_var['name']].values)

    if class_var['type'] == 'cir':
        return circular_mean(current.data.df[class_var['name']].values)

    raise Exception("Class variable has an unkown type:", class_var['type'], "has to be either ['lin', 'cir']")

def tree_mae_calc(node, df):
    """Return the mean absolute error of the tree's predictions over *df*.

    Args:
        node: Root node of the tree; ``node.data.class_var['name']`` names the
            column of *df* holding the observed values.
        df: DataFrame with one record per row, including the class variable.

    Returns:
        Sum of ``|prediction - observed|`` over all rows divided by the row count.

    Raises:
        ZeroDivisionError: If *df* has no rows.
    """
    # class_var is a dict; rows must be indexed by its 'name' entry, not by
    # the (unhashable) dict itself.
    target = node.data.class_var['name']
    total_len = len(df.index)
    acc = 0.0

    # NOTE(review): the plain difference ignores wrap-around for 'cir' class
    # variables (e.g. 359 vs 1) -- confirm whether that is intended.
    for _, row in df.iterrows():
        acc += abs(tree_eval(node, row) - row[target])

    return acc / total_len

def tree_rmse_calc(node, df):
    """Return the root mean squared error of the tree's predictions over *df*.

    Args:
        node: Root node of the tree; ``node.data.class_var['name']`` names the
            column of *df* holding the observed values.
        df: DataFrame with one record per row, including the class variable.

    Returns:
        Square root of the mean of ``(prediction - observed) ** 2`` over all rows.

    Raises:
        ZeroDivisionError: If *df* has no rows.
    """
    # class_var is a dict; rows must be indexed by its 'name' entry, not by
    # the (unhashable) dict itself.
    target = node.data.class_var['name']
    total_len = len(df.index)
    acc = 0.0

    # NOTE(review): the plain difference ignores wrap-around for 'cir' class
    # variables (e.g. 359 vs 1) -- confirm whether that is intended.
    for _, row in df.iterrows():
        acc += math.pow(tree_eval(node, row) - row[target], 2)

    return math.sqrt(acc / total_len)


def cxval_k_folds_split(df, k_folds, seed):
    """Randomly partition *df* into *k_folds* disjoint DataFrames.

    The first ``k_folds - 1`` folds each receive ``round(len(df) / k_folds)``
    rows sampled without replacement; the last fold receives whatever rows
    remain, so its size may differ from the others.

    Args:
        df: DataFrame to split. It is not modified; rows are selected and
            dropped by index label.
        k_folds: Number of folds to produce.
        seed: Seed passed to ``random.seed`` so the split is reproducible.
            Note that this reseeds the module-level random generator.

    Returns:
        A list of *k_folds* DataFrames.
    """
    random.seed(seed)

    group_size = int(round(df.shape[0] * (1.0 / k_folds)))
    dataframes = []

    for _ in range(k_folds - 1):
        rows = random.sample(list(df.index), group_size)
        # ``rows`` holds index labels, so select with .loc (DataFrame.ix was
        # removed from pandas).
        dataframes.append(df.loc[rows])
        df = df.drop(rows)

    # Whatever was never sampled forms the final fold.
    dataframes.append(df)

    return dataframes


def cxval_select_fold(i_fold, df_folds):
    """Split the folds into a training set and the held-out test fold.

    Args:
        i_fold: Position in *df_folds* of the fold to hold out for testing.
        df_folds: Sequence of DataFrames, one per fold. It is left untouched.

    Returns:
        A ``(train_df, test_df)`` tuple: the concatenation of every fold
        except *i_fold*, and a copy of fold *i_fold*.

    Raises:
        Exception: If *i_fold* is not a valid position in *df_folds*.
    """
    # Validate before doing any work.
    if not 0 <= i_fold < len(df_folds):
        raise Exception('Group not in range!')

    # Only the held-out fold needs an explicit copy; pd.concat already builds
    # a new frame for the training data, so deep-copying every fold is wasted.
    test_df = df_folds[i_fold].copy()
    train_df = pd.concat(
        [fold for position, fold in enumerate(df_folds) if position != i_fold]
    )

    return train_df, test_df

def cxval_test(df, class_var, var_desc, leaf_size, k_folds=5, seed=1):
    """Run k-fold cross-validation and return the averaged MAE and RMSE.

    The data is split into *k_folds* folds; each fold is held out once while
    a tree is trained on the remaining folds and scored on the held-out one.

    Args:
        df: Full DataFrame to cross-validate on.
        class_var: Class-variable description forwarded to the tree trainer.
        var_desc: Variable descriptions forwarded to the tree trainer.
        leaf_size: Leaf-size setting forwarded to the tree trainer.
        k_folds: Number of folds (default 5).
        seed: Seed for the random fold split (default 1).

    Returns:
        A ``(mean_mae, mean_rmse)`` tuple averaged over the folds.
    """
    folds = cxval_k_folds_split(df, k_folds, seed)
    mae_scores = []
    rmse_scores = []

    for held_out in range(k_folds):
        train_df, test_df = cxval_select_fold(held_out, folds)
        # NOTE(review): tree_trainer is neither defined nor imported in the
        # visible part of this file -- confirm where it is expected to come from.
        fitted_tree = tree_trainer(train_df, class_var, var_desc, leaf_size)

        mae_scores.append(tree_mae_calc(fitted_tree, test_df))
        rmse_scores.append(tree_rmse_calc(fitted_tree, test_df))

    mean_mae = sum(mae_scores) / k_folds
    mean_rmse = sum(rmse_scores) / k_folds

    return mean_mae, mean_rmse

# Command-line entry point: load a pickled tree and a CSV data file, then
# print one "observed,predicted" line per row of the data.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='AeroMetTree model training command line utility.')
    parser.add_argument('--data', dest='data_path', help='Path to the CSV file containing the data to train the tree', required=True)
    parser.add_argument('--model', dest='model_path', help='Path to the .save model', required=True)

    args = parser.parse_args()

    # SECURITY NOTE: unpickling can execute arbitrary code; only load model
    # files that come from a trusted source.
    with open(args.model_path, "rb") as model_file:
        tree = pickle.load(model_file)

    # Rows with any missing value are discarded before evaluation.
    df = pd.read_csv(args.data_path).dropna()

    target = tree.data.class_var['name']

    print("metar,model")
    for _, row in df.iterrows():
        print("{0},{1:0.4f}".format(row[target], tree_eval(tree, row)))
